{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Nifti Read Example\n",
    "\n",
    "The purpose of this notebook is to illustrate reading Nifti files and iterating over patches of the volumes loaded from them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MONAI version: 0.1a1.dev8+6.gb3c5761.dirty\n",
      "Python version: 3.6.9 |Anaconda, Inc.| (default, Jul 30 2019, 19:07:31)  [GCC 7.3.0]\n",
      "Numpy version: 1.18.1\n",
      "Pytorch version: 1.4.0\n",
      "Ignite version: 0.3.0\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import os\n",
    "import shutil\n",
    "import sys\n",
    "import tempfile\n",
    "from glob import glob\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import nibabel as nib\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import monai\n",
    "from monai.data import NiftiDataset, GridPatchDataset, create_test_image_3d\n",
    "from monai.transforms import Compose, AddChannel, Transpose, ScaleIntensity, ToTensor, RandSpatialCrop\n",
    "\n",
    "# Record the library versions used for this run in the cell output.\n",
    "monai.config.print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a number of test Nifti files:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write five test image/segmentation pairs as Nifti files into a fresh temporary directory.\n",
    "tempdir = tempfile.mkdtemp()\n",
    "\n",
    "for i in range(5):\n",
    "    test_image, test_seg = create_test_image_3d(128, 128, 128)\n",
    "\n",
    "    # Both volumes are saved with an identity affine as im<i>.nii.gz and seg<i>.nii.gz.\n",
    "    for pattern, volume in (('im%i.nii.gz', test_image), ('seg%i.nii.gz', test_seg)):\n",
    "        target = os.path.join(tempdir, pattern % i)\n",
    "        nib.save(nib.Nifti1Image(volume, np.eye(4)), target)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a data loader which yields uniform random patches from loaded Nifti files:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([5, 1, 64, 64, 64]) torch.Size([5, 1, 64, 64, 64])\n"
     ]
    }
   ],
   "source": [
    "# Collect the generated files; both lists are sorted so images[k] and segs[k]\n",
    "# carry the same file index.\n",
    "images, segs = (\n",
    "    sorted(glob(os.path.join(tempdir, pattern)))\n",
    "    for pattern in ('im*.nii.gz', 'seg*.nii.gz')\n",
    ")\n",
    "\n",
    "patch_size = (64, 64, 64)\n",
    "\n",
    "# The image pipeline applies ScaleIntensity first; otherwise both pipelines add a\n",
    "# channel dimension, take a fixed-size random crop and convert to a tensor.\n",
    "# NOTE(review): each pipeline owns a separate RandSpatialCrop instance; confirm that\n",
    "# NiftiDataset keeps the image crop and the segmentation crop spatially aligned.\n",
    "image_steps = [ScaleIntensity(), AddChannel(), RandSpatialCrop(patch_size, random_size=False), ToTensor()]\n",
    "seg_steps = [AddChannel(), RandSpatialCrop(patch_size, random_size=False), ToTensor()]\n",
    "imtrans = Compose(image_steps)\n",
    "segtrans = Compose(seg_steps)\n",
    "\n",
    "ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans)\n",
    "loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())\n",
    "\n",
    "# Take the first batch from the loader and report the image and segmentation shapes.\n",
    "im, seg = monai.utils.misc.first(loader)\n",
    "print(im.shape, seg.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Alternatively, create a data loader which yields patches in regular grid order from the loaded images:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([10, 1, 64, 64, 64]) torch.Size([10, 1, 64, 64, 64])\n"
     ]
    }
   ],
   "source": [
    "# No crop transform in these pipelines: the (64, 64, 64) patches come from\n",
    "# GridPatchDataset, which wraps the whole-volume dataset below.\n",
    "imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])\n",
    "segtrans = Compose([AddChannel(), ToTensor()])\n",
    "\n",
    "volume_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans)\n",
    "ds = GridPatchDataset(volume_ds, (64, 64, 64))\n",
    "\n",
    "loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())\n",
    "\n",
    "# Take the first batch from the loader and report the image and segmentation shapes.\n",
    "im, seg = monai.utils.misc.first(loader)\n",
    "print(im.shape, seg.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the temporary directory together with the Nifti files written into it.\n",
    "# shutil.rmtree replaces the shell call so the cleanup does not depend on an rm\n",
    "# binary being available and is not affected by spaces in the path.\n",
    "# ignore_errors=True keeps re-running this cell harmless once the directory is gone.\n",
    "shutil.rmtree(tempdir, ignore_errors=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}